package main

import (
	"crypto/sha1"
	"encoding/hex"
	"encoding/json"
	"encoding/xml"
	"errors"
	"fmt"
	"io/ioutil"
	"net/http"
	"os"
	"sort"
	"strings"
	"time"

	"github.com/gorilla/websocket"
)

// Per-user conversation state, keyed by the sender's FromUserName. These are
// plain Go maps: goroutines that access them concurrently must synchronize.
var (
	// userChatHistory holds the conversation turns sent to the model for each user.
	userChatHistory = map[string][]Message{}
	// userLastChatTime is the Unix time (seconds) of each user's latest message.
	userLastChatTime = map[string]int64{}
	// userStashMsg caches an answer that took longer than 4s to produce (the
	// handler must reply to WeChat within 5s), so it can be returned later.
	userStashMsg = map[string]string{}
	// userHasAnswerIng marks users whose question is still being answered, so
	// they are not handed an incomplete answer.
	userHasAnswerIng = map[string]bool{}
)

type (
	// XMLMessage holds the header fields shared by messages and events that
	// arrive as XML in the callback request body.
	XMLMessage struct {
		ToUserName   string `xml:"ToUserName"`
		FromUserName string `xml:"FromUserName"`
		CreateTime   int64  `xml:"CreateTime"`
		MsgType      string `xml:"MsgType"`
		MsgId        string `xml:"MsgId"`
		MsgDataId    string `xml:"MsgDataId"`
		Idx          string `xml:"Idx"`
		Event        string `xml:"Event"`
	}

	// TextMessage is an XMLMessage plus the Content element of a text message.
	TextMessage struct {
		XMLMessage
		Content string `xml:"Content"`
	}
)

// Passive-reply XML templates. Each takes, in order: the recipient
// (ToUserName), the sender (FromUserName), the creation time (Unix seconds),
// and the payload (text content or image MediaId).
var (
	textReplayTemplate = "\n" +
		"<xml>\n" +
		"  <ToUserName><![CDATA[%s]]></ToUserName>\n" +
		"  <FromUserName><![CDATA[%s]]></FromUserName>\n" +
		"  <CreateTime>%d</CreateTime>\n" +
		"  <MsgType><![CDATA[text]]></MsgType>\n" +
		"  <Content><![CDATA[%s]]></Content>\n" +
		"</xml>\n"

	ImgReplayTemplate = "\n" +
		"<xml>\n" +
		"  <ToUserName><![CDATA[%s]]></ToUserName>\n" +
		"  <FromUserName><![CDATA[%s]]></FromUserName>\n" +
		"  <CreateTime>%d</CreateTime>\n" +
		"  <MsgType><![CDATA[image]]></MsgType>\n" +
		"  <Image>\n" +
		"    <MediaId><![CDATA[%s]]></MediaId>\n" +
		"  </Image>\n" +
		"</xml>\n"
)

// userStateLock guards userChatHistory, userLastChatTime, userStashMsg and
// userHasAnswerIng for handlePostRequest and the Spark reader goroutines it
// starts. It is a one-slot buffered channel used as a mutex: a send acquires
// the lock and a receive releases it. Any other code touching these maps must
// take the same lock to be race-free.
var userStateLock = make(chan struct{}, 1)

func lockUserState() { userStateLock <- struct{}{} }

func unlockUserState() { <-userStateLock }

const (
	// replyDeadline is how long the handler waits for Spark before replying
	// with a placeholder; WeChat reports the service as unavailable if it gets
	// no reply within 5s.
	replyDeadline = 4 * time.Second
	// sparkReadTimeout bounds how long one Spark answer may stream, so a stalled
	// connection cannot leave a user marked as "still answering" forever.
	sparkReadTimeout = 2 * time.Minute
	// sessionTimeoutSeconds is the gap between messages, in seconds, after
	// which the user's chat history is discarded and a new conversation starts.
	sessionTimeoutSeconds = 300
	// sparkFailedReply is sent (or stashed) when no answer could be obtained.
	sparkFailedReply = "抱歉，暂时没能从星火大模型获得回复，请稍后再试。"
)

// sparkResponse is the subset of a Spark websocket frame read by readSparkAnswer.
type sparkResponse struct {
	Header struct {
		Code int `json:"code"`
	} `json:"header"`
	Payload struct {
		Choices struct {
			Status int `json:"status"`
			Text   []struct {
				Content string `json:"content"`
			} `json:"text"`
		} `json:"choices"`
		Usage struct {
			Text struct {
				TotalTokens int `json:"total_tokens"`
			} `json:"text"`
		} `json:"usage"`
	} `json:"payload"`
}

// sparkMessageReader is the part of the websocket connection that
// readSparkAnswer needs; *websocket.Conn satisfies it.
type sparkMessageReader interface {
	ReadMessage() (messageType int, p []byte, err error)
}

// handlePostRequest is the /wx callback.
//
// GET requests are the server-verification handshake (see verifyServerToken).
// POST requests carry an XML message or event:
//   - a subscribe event is answered with SUBSCRIBE_REPLY; other events get an
//     empty 200 response;
//   - a message whose text is a key in KEYWORD_REPLY gets that configured reply
//     (an "img:<MediaId>" value is sent as an image reply);
//   - while a previous answer is still being produced the user is told to wait;
//   - a stashed (late) answer is returned on the user's next message;
//   - anything else is sent to Spark via replyFromSpark.
func handlePostRequest(w http.ResponseWriter, r *http.Request) {
	if r.Method == http.MethodGet {
		verifyServerToken(w, r)
		return
	}

	body, err := ioutil.ReadAll(r.Body)
	if err != nil {
		http.Error(w, "Failed to read request body", http.StatusBadRequest)
		return
	}
	fmt.Println(string(body))

	var textMsg TextMessage
	if err := xml.Unmarshal(body, &textMsg); err != nil {
		http.Error(w, "Failed to parse XML", http.StatusBadRequest)
		return
	}
	timeNow := time.Now().Unix()
	user := textMsg.FromUserName

	if textMsg.MsgType == "event" {
		if textMsg.Event == "subscribe" {
			writeTextReply(w, textMsg, timeNow, os.Getenv("SUBSCRIBE_REPLY"))
			return
		}
		// Unsubscribe and all other events need no reply.
		w.WriteHeader(http.StatusOK)
		return
	}

	fmt.Println("收到消息：", textMsg.Content)
	if reply, ok := lookupKeywordReply(textMsg.Content); ok {
		if i := strings.Index(reply, "img:"); i >= 0 {
			// Image reply: everything after "img:" is the MediaId.
			mediaID := reply[i+len("img:"):]
			w.WriteHeader(http.StatusOK)
			_, _ = w.Write([]byte(fmt.Sprintf(ImgReplayTemplate, textMsg.FromUserName, textMsg.ToUserName, timeNow, escapeCDATA(mediaID))))
			return
		}
		writeTextReply(w, textMsg, timeNow, reply)
		return
	}

	lockUserState()
	if userHasAnswerIng[user] {
		unlockUserState()
		writeTextReply(w, textMsg, timeNow, "我还在思考中，耐心一点啦，但是我一定会回复你的，我保证。所以请继续回复任意文字尝试获取回复。比如数字 1。")
		return
	}
	if stashed, ok := userStashMsg[user]; ok {
		// The user is collecting an answer that was not ready in time earlier.
		delete(userStashMsg, user)
		unlockUserState()
		fmt.Println("用户有暂存数据，返回暂存数据")
		writeTextReply(w, textMsg, timeNow, stashed)
		return
	}
	unlockUserState()

	replyFromSpark(w, textMsg, timeNow)
}

// verifyServerToken handles the verification GET: it computes the SHA-1 of the
// sorted, concatenated (WX_TOKEN, timestamp, nonce) and echoes echostr back only
// when that matches the signature parameter.
func verifyServerToken(w http.ResponseWriter, r *http.Request) {
	parts := []string{os.Getenv("WX_TOKEN"), r.FormValue("timestamp"), r.FormValue("nonce")}
	sort.Strings(parts)
	sum := sha1.Sum([]byte(strings.Join(parts, "")))
	if hex.EncodeToString(sum[:]) != r.FormValue("signature") {
		w.WriteHeader(http.StatusInternalServerError)
		_, _ = w.Write([]byte("failed"))
		return
	}
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(r.FormValue("echostr")))
}

// lookupKeywordReply returns the reply configured for content in the
// KEYWORD_REPLY environment variable, a JSON object mapping text to reply.
func lookupKeywordReply(content string) (string, bool) {
	replies := make(map[string]string)
	// Best effort: an unset or malformed KEYWORD_REPLY just disables keyword replies.
	_ = json.Unmarshal([]byte(os.Getenv("KEYWORD_REPLY")), &replies)
	reply, ok := replies[content]
	return reply, ok
}

// escapeCDATA makes s safe to place inside a CDATA section: a literal "]]>"
// would end the section early, so it is split across two adjacent sections.
func escapeCDATA(s string) string {
	return strings.Replace(s, "]]>", "]]]]><![CDATA[>", -1)
}

// writeTextReply writes a text passive reply addressed back to msg's sender.
func writeTextReply(w http.ResponseWriter, msg TextMessage, createTime int64, content string) {
	w.WriteHeader(http.StatusOK)
	_, _ = w.Write([]byte(fmt.Sprintf(textReplayTemplate, msg.FromUserName, msg.ToUserName, createTime, escapeCDATA(content))))
}

// replyFromSpark sends textMsg plus the sender's recent history to Spark and
// replies with the answer. If Spark has not finished within replyDeadline, it
// replies with a placeholder and marks the user as still being answered. The
// reader goroutine then stashes the eventual answer (or sparkFailedReply) in
// userStashMsg for the user's next message.
func replyFromSpark(w http.ResponseWriter, textMsg TextMessage, timeNow int64) {
	user := textMsg.FromUserName

	connect, err := getConnect()
	if err != nil {
		fmt.Println(err.Error())
		writeTextReply(w, textMsg, timeNow, sparkFailedReply)
		return
	}

	lockUserState()
	if last, seen := userLastChatTime[user]; !seen || timeNow-last >= sessionTimeoutSeconds {
		// Start a fresh conversation after a long gap, or when this user's
		// timestamp has already been pruned (the history is stale then too).
		userChatHistory[user] = nil
	}
	userLastChatTime[user] = timeNow
	userChatHistory[user] = append(userChatHistory[user], Message{
		Role:    "user",
		Content: textMsg.Content,
	})
	history := userChatHistory[user]
	unlockUserState()

	// Best effort: if the deadline cannot be set, reads are merely unbounded.
	_ = connect.SetReadDeadline(time.Now().Add(sparkReadTimeout))
	if err := connect.WriteJSON(genParams(history)); err != nil {
		fmt.Println("write request error:", err)
		_ = connect.Close()
		writeTextReply(w, textMsg, timeNow, sparkFailedReply)
		return
	}

	// Capacity 1 so the reader goroutine never blocks when handing over a result.
	results := make(chan string, 1)
	// timedOut is set by this handler and read by the reader goroutine, both
	// while holding userStateLock; it decides who delivers the answer.
	timedOut := false

	go func() {
		defer connect.Close()
		answer, ok := readSparkAnswer(connect)

		lockUserState()
		defer unlockUserState()
		userHasAnswerIng[user] = false
		if ok {
			userChatHistory[user] = append(userChatHistory[user], Message{
				Role:    "assistant",
				Content: answer,
			})
		} else {
			answer = sparkFailedReply
		}
		if timedOut {
			// The placeholder was already sent; keep the answer for the next message.
			userStashMsg[user] = answer
			return
		}
		results <- answer
	}()

	timer := time.NewTimer(replyDeadline)
	defer timer.Stop()
	select {
	case answer := <-results:
		writeTextReply(w, textMsg, time.Now().Unix(), answer)
		return
	case <-timer.C:
	}

	lockUserState()
	select {
	case answer := <-results:
		// The answer arrived between the timer firing and taking the lock.
		unlockUserState()
		writeTextReply(w, textMsg, time.Now().Unix(), answer)
		return
	default:
	}
	timedOut = true
	userHasAnswerIng[user] = true
	unlockUserState()

	fmt.Println("执行超过4s，提前返回")
	writeTextReply(w, textMsg, timeNow, "微信规定要在5s内回复，但是我正在思考中，所以你暂时看到了这条消息。请稍后回复任意文字尝试获取回复。比如数字 1。")
}

// readSparkAnswer reads streamed frames from conn until the final one
// (choices.status == 2) and returns the concatenated content with ok=true.
// It returns ok=false if reading fails, a frame is not valid JSON, or Spark
// reports a non-zero header code.
func readSparkAnswer(conn sparkMessageReader) (string, bool) {
	var answer strings.Builder
	for {
		_, raw, err := conn.ReadMessage()
		if err != nil {
			fmt.Println("read message error:", err)
			return answer.String(), false
		}

		var frame sparkResponse
		if err := json.Unmarshal(raw, &frame); err != nil {
			fmt.Println("Error parsing JSON:", err)
			return answer.String(), false
		}
		fmt.Println(string(raw))

		if frame.Header.Code != 0 {
			fmt.Println("未从星火获得结果：", string(raw))
			return answer.String(), false
		}

		choices := frame.Payload.Choices
		if len(choices.Text) > 0 {
			answer.WriteString(choices.Text[0].Content)
		}
		if choices.Status == 2 {
			fmt.Println("收到最终结果")
			fmt.Println("total_tokens:", frame.Payload.Usage.Text.TotalTokens)
			return answer.String(), true
		}
	}
}
// errSparkHandshake reports that the Spark websocket handshake finished with
// a status other than 101 Switching Protocols.
var errSparkHandshake = errors.New("unexpected websocket handshake status")

// getConnect dials the Spark websocket URL produced by assembleAuthUrl1, with
// a 5s handshake timeout.
//
// On success the caller owns the returned connection and must close it. On
// failure the connection is nil and the error describes the cause.
func getConnect() (*websocket.Conn, error) {
	dialer := websocket.Dialer{HandshakeTimeout: 5 * time.Second}

	conn, resp, err := dialer.Dial(assembleAuthUrl1(), nil)
	if err != nil {
		return nil, fmt.Errorf("连接到讯飞星火大模型失败%w", err)
	}
	// err is nil here, so it must not be dereferenced; report the status instead.
	if resp.StatusCode != http.StatusSwitchingProtocols {
		_ = conn.Close()
		return nil, fmt.Errorf("连接到讯飞星火大模型失败: %w (%d)", errSparkHandshake, resp.StatusCode)
	}
	return conn, nil
}
func main() {
	go func() {
		//定时清理一下变量，防止内存泄漏
		ticker := time.NewTicker(time.Minute * 2)
		for {
			select {
			case <-ticker.C:
				fmt.Println("清理内存")
				for s, b := range userHasAnswerIng {
					if !b {
						delete(userHasAnswerIng, s)
					}
				}
				timeNow := time.Now().Unix()
				//因为5分钟就会算新会话，我们把上次聊天大于6分钟的，都清理掉
				for s, i := range userLastChatTime {
					if timeNow-i > 360 {
						delete(userLastChatTime, s)
					}
				}
			}

		}
	}()
	http.HandleFunc("/wx", handlePostRequest)
	fmt.Println("程序启动，运行在" + os.Getenv("SERVER_PORT") + "端口")
	http.ListenAndServe(":"+os.Getenv("SERVER_PORT"), nil)
}
